package fun.ticsmyc.rpc.client.transport.netty;

import fun.ticsmyc.rpc.client.transport.netty.codec.RequestEncoder;
import fun.ticsmyc.rpc.client.transport.netty.codec.ResponseDecoder;
import fun.ticsmyc.rpc.client.transport.netty.handler.NettyClientHandler;
import fun.ticsmyc.rpc.common.enumeration.InitializeError;
import fun.ticsmyc.rpc.common.enumeration.RpcError;
import fun.ticsmyc.rpc.common.exception.InitializeException;
import fun.ticsmyc.rpc.common.exception.RpcException;
import fun.ticsmyc.rpc.common.serializer.Serializer;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.timeout.IdleStateHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.net.InetSocketAddress;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;

/**
 * Factory that produces and caches {@link RpcRequestSender} instances, keyed by remote address.
 *
 * <p>{@link #init(Serializer)} must be called to specify the serializer before
 * {@link #getRpcRequestSender(InetSocketAddress)} is used; otherwise an
 * {@link InitializeException} is thrown.
 *
 * @author Ticsmyc
 * @date 2020-11-02 10:11
 */
public class RpcRequestSenderFactory {

    private static final Logger logger = LoggerFactory.getLogger(RpcRequestSenderFactory.class);

    /** Connection timeout applied to every outgoing connection, in milliseconds. */
    private static final int CONNECT_TIMEOUT_MILLIS = 5000;
    /** Seconds without a write after which the idle handler fires a writer-idle event. */
    private static final int WRITER_IDLE_SECONDS = 5;

    /** Senders already created, keyed by {@code InetSocketAddress.toString()}. */
    private static final Map<String, RpcRequestSender> rpcRequestSenderCache = new ConcurrentHashMap<>();

    /** Non-null only after a fully configured bootstrap has been published by {@link #init(Serializer)}. */
    private static volatile Bootstrap bootstrap;
    /** Event loop group backing {@link #bootstrap}; accessed only under the class lock. */
    private static NioEventLoopGroup group;

    /**
     * Returns a sender for the given address, reusing a cached one when it is still alive.
     *
     * <p>Concurrent callers for the same, not-yet-cached address may each create their own
     * sender; the last one stored wins the cache slot.
     *
     * @param inetSocketAddress remote address to send requests to
     * @return a sender whose {@code isAlive()} returned {@code true}
     * @throws InitializeException if {@link #init(Serializer)} has not been called
     * @throws RpcException        if a newly created sender is not alive
     */
    public static RpcRequestSender getRpcRequestSender(InetSocketAddress inetSocketAddress) {
        // Read the volatile field once so the null check and the use see the same instance.
        Bootstrap currentBootstrap = bootstrap;
        if (currentBootstrap == null) {
            throw new InitializeException(InitializeError.RPC_REQUEST_SENDER_FACTORY_NOT_INIT);
        }
        String key = inetSocketAddress.toString();

        // A single get() avoids the containsKey/get race that could yield null.
        RpcRequestSender cached = rpcRequestSenderCache.get(key);
        if (cached != null) {
            if (cached.isAlive()) {
                logger.debug("复用之前创建好的RpcRequestSender :{}",inetSocketAddress);
                return cached;
            }
            // Conditional remove: do not evict a fresh sender another thread may have just stored.
            rpcRequestSenderCache.remove(key, cached);
        }

        RpcRequestSender created = new RpcRequestSender(currentBootstrap, inetSocketAddress);
        if (!created.isAlive()) {
            throw new RpcException(RpcError.CONNECT_ERROR);
        }
        rpcRequestSenderCache.put(key, created);
        return created;
    }

    /**
     * Releases the factory's resources: un-publishes the bootstrap, drops all cached senders and
     * shuts the event loop group down. After this call {@link #init(Serializer)} must be invoked
     * again before senders can be obtained.
     */
    public static synchronized void close() {
        bootstrap = null;
        rpcRequestSenderCache.clear();
        if (group != null) {
            group.shutdownGracefully();
            group = null;
        }
    }

    /**
     * Initializes (or re-initializes) the factory with the serializer used to encode requests.
     *
     * <p>On re-initialization the sender cache is cleared and the previous event loop group is
     * shut down, so its threads are not leaked.
     *
     * @param serializer serializer handed to the {@link RequestEncoder} of every new channel
     */
    public static synchronized void init(Serializer serializer) {
        NioEventLoopGroup previousGroup = group;
        NioEventLoopGroup newGroup = new NioEventLoopGroup();

        // Configure on a local first; the volatile field is assigned only when setup is complete,
        // so readers never observe a half-configured bootstrap.
        Bootstrap newBootstrap = new Bootstrap();
        newBootstrap.group(newGroup)
                .channel(NioSocketChannel.class)
                // Connection timeout.
                .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, CONNECT_TIMEOUT_MILLIS)
                // Enable the TCP-level keep-alive mechanism.
                .option(ChannelOption.SO_KEEPALIVE, true)
                // TCP enables Nagle's algorithm by default, which batches small writes into larger
                // segments to reduce network traffic. TCP_NODELAY=true turns that batching off.
                .option(ChannelOption.TCP_NODELAY, true)
                .handler(new ChannelInitializer<SocketChannel>() {
                    @Override
                    protected void initChannel(SocketChannel ch) {
                        ChannelPipeline pipeline = ch.pipeline();
                        // Encoder and decoder.
                        pipeline.addLast(new RequestEncoder(serializer));
                        pipeline.addLast(new ResponseDecoder());
                        // Heartbeat support: a writer-idle event fires after 5s without a write
                        // (presumably NettyClientHandler reacts by sending a heartbeat - confirm there).
                        pipeline.addLast(new IdleStateHandler(0, WRITER_IDLE_SECONDS, 0, TimeUnit.SECONDS));
                        // Business handler.
                        pipeline.addLast(new NettyClientHandler());
                    }
                });

        // Senders built on the previous bootstrap must not be handed out any more.
        rpcRequestSenderCache.clear();
        group = newGroup;
        bootstrap = newBootstrap;

        if (previousGroup != null) {
            previousGroup.shutdownGracefully();
        }
    }

}
